// Copyright Epic Games, Inc. All Rights Reserved.

#include <zennet/statsdclient.h>

#include <zencore/logging.h>
#include <zencore/string.h>
#include <zencore/testing.h>
#include <zencore/thread.h>

#include <deque>

ZEN_THIRD_PARTY_INCLUDES_START
#include <zencore/windows.h>
#include <asio.hpp>
ZEN_THIRD_PARTY_INCLUDES_END

namespace zen {

// Out-of-line (defaulted) destructor definitions for the two interface types.
// The classes themselves are declared outside this file (presumably in
// <zennet/statsdclient.h> - confirm).
StatsTransportBase::~StatsTransportBase() = default;
StatsDaemonClient::~StatsDaemonClient()	  = default;

//////////////////////////////////////////////////////////////////////////

// Transport which forwards each message to a statsd daemon as a single UDP
// datagram.
//
// The socket is opened and the target host resolved once, at construction. If
// either step fails the socket is released and SendMessage() becomes a no-op.
class StatsdUdpClient : public StatsTransportBase
{
	using udp = asio::ip::udp;

public:
	// TargetHost/TargetPort identify the daemon; they are resolved immediately
	StatsdUdpClient(std::string_view TargetHost, uint16_t TargetPort) : m_TargetHost(TargetHost), m_TargetPort(TargetPort)
	{
		InitializeSocket();
	}
	~StatsdUdpClient() { Shutdown(); }

	// Sends Size bytes starting at Data to the resolved endpoint. Send errors
	// are captured in a local error code and otherwise ignored.
	virtual void SendMessage(const void* Data, size_t Size) override
	{
		if (!m_Socket)
		{
			// Initialization failed (or we have been shut down) - nothing to send on
			return;
		}

		std::error_code SendError;
		m_Socket->send_to(asio::buffer(Data, Size), m_ServerAddr, 0, SendError);
	}

private:
	void InitializeSocket();
	void Shutdown();

	// NOTE: m_IoService is declared ahead of m_Socket so that it is constructed
	// before, and destroyed after, the socket that is created from it
	asio::io_service m_IoService;
	std::string		 m_TargetHost;
	uint16_t		 m_TargetPort;

	udp::endpoint				 m_ServerAddr;
	std::unique_ptr<udp::socket> m_Socket;
};

void
StatsdUdpClient::InitializeSocket()
{
	std::error_code Ec;
	m_Socket = std::make_unique<udp::socket>(m_IoService);
	m_Socket->open(udp::v4(), Ec);

	if (Ec)
	{
		ZEN_WARN("StatsdUdpClient::InitializeSocket socket creation failed: {}", Ec.message());

		return m_Socket.reset();
	}

	udp::resolver			DnsResolver(m_IoService);
	udp::resolver::query	DnsQuery(udp::v4(), m_TargetHost, fmt::format("{}", m_TargetPort));
	udp::resolver::iterator It = DnsResolver.resolve(DnsQuery, Ec);

	if (Ec)
	{
		ZEN_WARN("StatsdUdpClient::InitializeSocket resolve of '{}:{}' failed: {}", m_TargetHost, m_TargetPort, Ec.message());

		return m_Socket.reset();
	}

	m_ServerAddr = *It;
}

// Stops sending and releases the socket. Safe to call when the socket was
// never created or has already been released.
//
// This runs from the destructor, so it must not let an exception escape. The
// error_code overload of shutdown() is used for that reason: the socket is only
// ever used with send_to() and is never connected, so shutdown() may well
// report a failure (e.g. "not connected" on some platforms). Any such error is
// deliberately ignored since the socket is closed immediately afterwards when
// it is destroyed.
void
StatsdUdpClient::Shutdown()
{
	if (!m_Socket)
	{
		return;
	}

	std::error_code Ec;
	m_Socket->shutdown(asio::socket_base::shutdown_send, Ec);

	m_Socket.reset();
}

//////////////////////////////////////////////////////////////////////////

// Transport which captures every message in memory instead of sending it
// anywhere. Used by the tests in this file to inspect what the client emitted.
//
// Both member functions take m_Lock exclusively, so messages may be recorded
// and drained from different threads.
class StatsdMemoryClient : public StatsTransportBase
{
	using MessageBuffer_t = std::vector<uint8_t>;

public:
	StatsdMemoryClient() {}
	~StatsdMemoryClient() {}

	// Stores a copy of the Size bytes starting at Data as one message
	virtual void SendMessage(const void* Data, size_t Size) override
	{
		const uint8_t* Begin = static_cast<const uint8_t*>(Data);

		RwLock::ExclusiveLockScope _(m_Lock);
		m_Messages.emplace_back(Begin, Begin + Size);
	}

	// Returns all messages recorded so far, in the order they were sent, and
	// leaves the internal list empty.
	//
	// The contents are swapped into a fresh deque rather than returned via
	// std::move(m_Messages): a moved-from deque is only guaranteed to be in a
	// valid but unspecified state, whereas swapping with an empty container
	// guarantees that subsequent calls start from an empty list.
	std::deque<MessageBuffer_t> GetMessages()
	{
		std::deque<MessageBuffer_t> Drained;

		RwLock::ExclusiveLockScope _(m_Lock);
		Drained.swap(m_Messages);

		return Drained;
	}

private:
	RwLock						m_Lock;
	std::deque<MessageBuffer_t> m_Messages;
};

//////////////////////////////////////////////////////////////////////////

// Collects individual statsd metric lines into newline-separated payloads and
// hands them to a transport.
//
// Behaviour depends on the configured message size:
//
//  * A metric whose length is >= the message size is passed straight to the
//    transport without touching the batch buffer. The message size starts out
//    as 0, so until SetMessageSize() is called every metric takes this path.
//  * Any other metric is appended to the batch buffer, separated from the
//    previous one by '\n'. If appending would bring the buffer up to the
//    message size, the current buffer contents are dispatched first.
//
// "Dispatched" means queued for the reporter thread if one has been started,
// or sent directly on the calling thread otherwise.
//
// The batch buffer and the queue are protected by m_MessageLock. m_MessageSize
// and m_UseBackgroundThread are NOT protected, so configuration is expected to
// happen before metrics are sent from other threads.
class StatsMessageBuilder
{
public:
	// The transport is held by reference and must outlive this object
	explicit StatsMessageBuilder(StatsTransportBase& Transport) : m_Transport(Transport) { m_AcceptMessages.test_and_set(); }

	// With a reporter thread: stop accepting metrics, request shutdown and wait
	// for the thread (which performs a final FlushAll() before exiting).
	// Without one: send whatever is left in the batch buffer.
	~StatsMessageBuilder()
	{
		if (m_MessagingThread.joinable())
		{
			m_AcceptMessages.clear();
			m_ShutdownRequested.Set();
			m_MessagingThread.join();
		}
		else
		{
			FlushMessageBuffer();
		}
	}

	// Selects whether SetMessageSize() may start the reporter thread. Has no
	// effect on a thread which is already running.
	void SetUseBackgroundThread(bool UseBackgroundThread) { m_UseBackgroundThread = UseBackgroundThread; }

	// Sets the batching threshold (in bytes) and, if background sending is
	// enabled and no reporter thread exists yet, starts one.
	void SetMessageSize(size_t MessageSize)
	{
		m_MessageSize = MessageSize;

		if (m_UseBackgroundThread && !m_MessagingThread.joinable())
		{
			m_MessagingThread = std::thread([this] {
				SetCurrentThreadName("statsd_reporter");

				ZEN_INFO("statsd reporting thread started");

				bool ShutdownRequested = false;

				do
				{
					// The loop ends once Wait() reports true.
					// NOTE(review): the argument is presumably a timeout in
					// milliseconds, making this a periodic flush - confirm
					// against Event::Wait
					ShutdownRequested = m_ShutdownRequested.Wait(100);

					FlushAll();
				} while (!ShutdownRequested);

				// Pick up anything which arrived during the last iteration
				FlushAll();

				ZEN_INFO("statsd reporting thread exiting");
			});
		}
	}

	// Submits one formatted metric line (without trailing newline). Dropped
	// silently once the destructor has started shutting the thread down.
	void SendMetric(std::string_view MetricMessage)
	{
		if (!m_AcceptMessages.test())
		{
			return;
		}

		// Too large to ever fit into a batch (or batching is disabled because
		// the message size is 0) - send on its own, bypassing the buffer
		if (MetricMessage.size() >= m_MessageSize)
		{
			return SendMessage(MetricMessage);
		}

		RwLock::ExclusiveLockScope _(m_MessageLock);

		if ((m_MessageBuffer.size() + MetricMessage.size()) >= m_MessageSize)
		{
			// This will clear the message buffer
			EnqueueOrSendCurrentMessage(_);
		}

		// Metrics within one payload are separated by (not terminated with) '\n'
		if (!m_MessageBuffer.empty())
		{
			m_MessageBuffer.push_back(uint8_t('\n'));
		}

		m_MessageBuffer.insert(m_MessageBuffer.end(),
							   reinterpret_cast<const uint8_t*>(MetricMessage.data()),
							   reinterpret_cast<const uint8_t*>(MetricMessage.data() + MetricMessage.size()));
	}

	void FlushMessageBuffer();
	void FlushQueue();

	// Sends all queued payloads followed by the partially filled batch buffer
	void FlushAll()
	{
		FlushQueue();
		FlushMessageBuffer();
	}

private:
	StatsTransportBase&	 m_Transport;
	std::thread			 m_MessagingThread;
	Event				 m_ShutdownRequested;
	std::atomic_flag	 m_AcceptMessages;	// set on construction, cleared by the destructor
	RwLock				 m_MessageLock;		// guards m_MessageBuffer and m_MessageQueue
	std::vector<uint8_t> m_MessageBuffer;	// payload currently being assembled
	size_t				 m_MessageSize		   = 0;
	bool				 m_UseBackgroundThread = true;

	std::deque<std::vector<uint8_t>> m_MessageQueue;  // complete payloads awaiting the reporter thread

	// Dispatches the current batch buffer and leaves it empty. The lock scope
	// parameter is unused; it exists so that callers must hold m_MessageLock.
	void EnqueueOrSendCurrentMessage(RwLock::ExclusiveLockScope&)
	{
		if (m_MessagingThread.joinable())
		{
			// Hand the storage over to the queue and prepare a fresh buffer
			m_MessageQueue.emplace_back(std::move(m_MessageBuffer));
			m_MessageBuffer.reserve(m_MessageSize);
		}
		else
		{
			m_Transport.SendMessage(m_MessageBuffer.data(), m_MessageBuffer.size());
			m_MessageBuffer.clear();
		}
	}

	// Sends a single payload directly, without taking m_MessageLock
	void SendMessage(std::string_view Message) { m_Transport.SendMessage(Message.data(), Message.size()); }
};

// Sends the partially assembled batch buffer (if it holds anything) as one
// payload and empties it. The transport call is made while m_MessageLock is
// held.
void
StatsMessageBuilder::FlushMessageBuffer()
{
	RwLock::ExclusiveLockScope BufferGuard(m_MessageLock);

	if (!m_MessageBuffer.empty())
	{
		m_Transport.SendMessage(m_MessageBuffer.data(), m_MessageBuffer.size());
		m_MessageBuffer.clear();
	}
}

// Sends every queued payload, oldest first.
//
// The queue is detached while holding m_MessageLock and the actual transport
// calls are made after the lock has been released, so producers are not
// blocked while payloads are being sent.
void
StatsMessageBuilder::FlushQueue()
{
	std::deque<std::vector<uint8_t>> Pending;

	{
		RwLock::ExclusiveLockScope QueueGuard(m_MessageLock);

		Pending.swap(m_MessageQueue);
	}

	for (const std::vector<uint8_t>& Payload : Pending)
	{
		m_Transport.SendMessage(Payload.data(), Payload.size());
	}
}

//////////////////////////////////////////////////////////////////////////

// Concrete StatsDaemonClient: formats each metric as a statsd text line and
// passes it to a StatsMessageBuilder, which in turn batches lines and pushes
// them to the owned transport.
class StatsDaemonClientImpl : public StatsDaemonClient
{
public:
	// Takes ownership of Transport. It must be non-null since the constructor
	// dereferences it to bind the message builder.
	explicit StatsDaemonClientImpl(std::unique_ptr<StatsTransportBase>&& Transport);
	~StatsDaemonClientImpl();

	virtual void Decrement(std::string_view Metric) override;
	virtual void Increment(std::string_view Metric) override;
	virtual void Count(std::string_view Metric, int64_t CountDelta) override;
	virtual void Gauge(std::string_view Metric, uint64_t CurrentValue) override;
	virtual void Meter(std::string_view Metric, uint64_t IncrementValue) override;
	virtual void SetMessageSize(size_t MessageSize, bool UseThreads) override;
	virtual void Flush() override;

private:
	// Declaration order matters: the builder holds a reference to *m_Transport,
	// so the transport has to be initialized first and destroyed last (the
	// builder flushes through the transport from its destructor).
	std::unique_ptr<StatsTransportBase> m_Transport;
	StatsMessageBuilder					m_MessageBuilder;
};

// Creates a client which reports to the statsd daemon at TargetHost:TargetPort
// over UDP.
std::unique_ptr<StatsDaemonClient>
CreateStatsDaemonClient(std::string_view TargetHost, uint16_t TargetPort)
{
	auto UdpTransport = std::make_unique<StatsdUdpClient>(TargetHost, TargetPort);

	return std::make_unique<StatsDaemonClientImpl>(std::move(UdpTransport));
}

std::unique_ptr<StatsDaemonClient>
CreateStatsDaemonClient(std::unique_ptr<StatsTransportBase>&& Transport)
{
	return std::make_unique<StatsDaemonClientImpl>(std::move(Transport));
}

//////////////////////////////////////////////////////////////////////////

StatsDaemonClientImpl::StatsDaemonClientImpl(std::unique_ptr<StatsTransportBase>&& Transport)
: m_Transport(std::move(Transport))
, m_MessageBuilder(*m_Transport)
{
}

// Nothing to do explicitly; the members clean up after themselves
StatsDaemonClientImpl::~StatsDaemonClientImpl() = default;

void
StatsDaemonClientImpl::Increment(std::string_view Metric)
{
	ExtendableStringBuilder<128> MetricMessage;
	MetricMessage << Metric << ":1|c";
	m_MessageBuilder.SendMetric(MetricMessage);
}

void
StatsDaemonClientImpl::Decrement(std::string_view Metric)
{
	ExtendableStringBuilder<128> MetricMessage;
	MetricMessage << Metric << ":-1|c";
	m_MessageBuilder.SendMetric(MetricMessage);
}

void
StatsDaemonClientImpl::Count(std::string_view Metric, int64_t CountDelta)
{
	ExtendableStringBuilder<128> MetricMessage;
	MetricMessage << Metric << ":" << CountDelta << "|c";
	m_MessageBuilder.SendMetric(MetricMessage);
}

void
StatsDaemonClientImpl::Gauge(std::string_view Metric, uint64_t CurrentValue)
{
	ExtendableStringBuilder<128> MetricMessage;
	MetricMessage << Metric << ":" << CurrentValue << "|g";
	m_MessageBuilder.SendMetric(MetricMessage);
}

void
StatsDaemonClientImpl::Meter(std::string_view Metric, uint64_t IncrementValue)
{
	ExtendableStringBuilder<128> MetricMessage;
	MetricMessage << Metric << ":" << IncrementValue << "|m";
	m_MessageBuilder.SendMetric(MetricMessage);
}

// Configures batching: MessageSize is the payload size threshold and UseThreads
// selects whether payloads are sent from a background thread.
void
StatsDaemonClientImpl::SetMessageSize(size_t MessageSize, bool UseThreads)
{
	// Order matters: StatsMessageBuilder::SetMessageSize() reads the background
	// thread flag when deciding whether to start the reporter thread
	m_MessageBuilder.SetUseBackgroundThread(UseThreads);
	m_MessageBuilder.SetMessageSize(MessageSize);
}

// Sends everything buffered so far (queued payloads first, then the partially
// filled batch) on the calling thread.
void
StatsDaemonClientImpl::Flush()
{
	m_MessageBuilder.FlushAll();
}

//////////////////////////////////////////////////////////////////////////

#if ZEN_WITH_TESTS

// Intentionally empty.
// NOTE(review): presumably referenced from elsewhere so that the linker keeps
// this translation unit (and the test cases below) - confirm against callers
void
statsd_forcelink()
{
}

// Verifies the wire format of every metric type. No message size is configured
// here, so each metric is expected to reach the transport as its own message.
TEST_CASE("zennet.statsd.emit")
{
	//	auto Client = CreateStatsDaemonClient("localhost", 8125);
	auto				MemoryClient	= std::make_unique<StatsdMemoryClient>();
	StatsdMemoryClient* RawMemoryClient = MemoryClient.get();
	auto				Client			= CreateStatsDaemonClient(std::move(MemoryClient));

	Client->Count("test.counter1", 1);
	Client->Count("test.counter2", 2);
	Client->Meter("test.meter1", 42);
	Client->Meter("test.meter2", 42);
	Client->Meter("test.meter2", 42);
	Client->Increment("test.count");
	Client->Decrement("test.count");
	Client->Increment("test.count");
	Client->Gauge("test.gauge", 1);
	Client->Gauge("test.gauge", 99);

	// Expected payloads, in the order the metrics were emitted above
	const std::string_view FormattedMessages[] = {"test.counter1:1|c",
												  "test.counter2:2|c",
												  "test.meter1:42|m",
												  "test.meter2:42|m",
												  "test.meter2:42|m",
												  "test.count:1|c",
												  "test.count:-1|c",
												  "test.count:1|c",
												  "test.gauge:1|g",
												  "test.gauge:99|g"};

	const size_t ExpectedCount = std::size(FormattedMessages);

	auto Messages = RawMemoryClient->GetMessages();

	CHECK_EQ(Messages.size(), ExpectedCount);

	// CHECK_EQ does not abort the test case on failure, so only compare the
	// entries which exist on both sides to avoid indexing out of bounds when
	// the message count is wrong
	const size_t ComparableCount = (Messages.size() < ExpectedCount) ? Messages.size() : ExpectedCount;

	for (size_t i = 0; i < ComparableCount; ++i)
	{
		std::string_view Message(reinterpret_cast<const char*>(Messages[i].data()), Messages[i].size());
		CHECK_EQ(Message, FormattedMessages[i]);
	}
}

// Verifies batching without a background thread: many small metrics must be
// packed into a known number of payloads, none larger than the configured
// message size.
TEST_CASE("zennet.statsd.batch")
{
	// To try this against a real daemon, create the client with
	// CreateStatsDaemonClient("localhost", 8125) instead
	auto				Transport = std::make_unique<StatsdMemoryClient>();
	StatsdMemoryClient* Capture	  = Transport.get();
	auto				Client	  = CreateStatsDaemonClient(std::move(Transport));

	const int  MaxMessageSize = 1000;
	const bool UseThreads	  = false;
	Client->SetMessageSize(MaxMessageSize, UseThreads);

	// A handful of metrics of each kind...
	Client->Count("test.counter1", 1);
	Client->Count("test.counter2", 2);
	Client->Meter("test.meter1", 42);
	Client->Meter("test.meter2", 42);
	Client->Meter("test.meter2", 42);
	Client->Increment("test.count");
	Client->Decrement("test.count");
	Client->Increment("test.count");
	Client->Gauge("test.gauge", 1);
	Client->Gauge("test.gauge", 99);

	// ...followed by enough identical ones to fill several payloads
	for (int Remaining = 1000; Remaining > 0; --Remaining)
	{
		Client->Increment("test.count1000");
	}

	// Push out the final, partially filled payload
	Client->Flush();

	auto Payloads = Capture->GetMessages();

	CHECK_EQ(Payloads.size(), 20);

	for (const auto& Payload : Payloads)
	{
		CHECK(Payload.size() <= MaxMessageSize);
	}
}

#endif

}  // namespace zen
